// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::num::NonZeroU64;
use std::sync::Arc;

use anyhow::anyhow;
use async_recursion::async_recursion;
use chrono::{MappedLocalTime, TimeZone};
use enum_as_inner::EnumAsInner;
use futures::TryStreamExt;
use iceberg::expr::Predicate as IcebergPredicate;
use itertools::Itertools;
use petgraph::{Directed, Graph};
use pgwire::pg_server::SessionId;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::bail;
use risingwave_common::bitmap::{Bitmap, BitmapBuilder};
use risingwave_common::catalog::{Schema, TableDesc};
use risingwave_common::hash::table_distribution::TableDistribution;
use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping};
use risingwave_common::types::Timestamptz;
use risingwave_common::util::scan_range::ScanRange;
use risingwave_connector::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use risingwave_connector::source::filesystem::opendal_source::{
    OpendalAzblob, OpendalGcs, OpendalS3,
};
use risingwave_connector::source::iceberg::{IcebergSplitEnumerator, IcebergTimeTravelInfo};
use risingwave_connector::source::kafka::KafkaSplitEnumerator;
use risingwave_connector::source::reader::reader::build_opendal_fs_list_for_batch;
use risingwave_connector::source::{
    ConnectorProperties, SourceEnumeratorContext, SplitEnumerator, SplitImpl,
};
use risingwave_pb::batch_plan::iceberg_scan_node::IcebergScanType;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::{ExchangeInfo, ScanRange as ScanRangeProto};
use risingwave_pb::plan_common::Field as PbField;
use risingwave_sqlparser::ast::AsOf;
use serde::ser::SerializeStruct;
use serde::Serialize;
use uuid::Uuid;

use super::SchedulerError;
use crate::catalog::catalog_service::CatalogReader;
use crate::catalog::TableId;
use crate::error::RwError;
use crate::optimizer::plan_node::generic::{GenericPlanRef, PhysicalPlanRef};
use crate::optimizer::plan_node::{
    BatchIcebergScan, BatchKafkaScan, BatchSource, PlanNodeId, PlanNodeType,
};
use crate::optimizer::property::Distribution;
use crate::optimizer::PlanRef;
use crate::scheduler::SchedulerResult;

/// Identifier of a batch query, stored as a string.
///
/// The `Default` implementation in this file fills `id` with a freshly generated v4 UUID.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct QueryId {
    /// The textual identifier.
    pub id: String,
}

impl std::fmt::Display for QueryId {
    /// Writes the id prefixed with `QueryId:`, e.g. `QueryId:abc`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("QueryId:")?;
        f.write_str(&self.id)
    }
}

/// Identifier of a stage; used as the key of the stage maps in `StageGraph`.
pub type StageId = u32;

/// Root stage always has only one task.
pub const ROOT_TASK_ID: u64 = 0;
/// Root task has only one output.
pub const ROOT_TASK_OUTPUT_ID: u64 = 0;
/// Identifier of a task.
pub type TaskId = u64;

/// Generated by [`BatchPlanFragmenter`] and used in query execution graph.
#[derive(Clone, Debug)]
pub struct ExecutionPlanNode {
    /// Id of the plan node this execution node was converted from.
    pub plan_node_id: PlanNodeId,
    /// Type of the plan node this execution node was converted from.
    pub plan_node_type: PlanNodeType,
    /// Protobuf body of the node. Not emitted by the `Serialize` implementation.
    pub node: NodeBody,
    /// Schema of the node as protobuf fields.
    pub schema: Vec<PbField>,

    /// Child nodes. Empty right after conversion from a `PlanRef`.
    pub children: Vec<Arc<ExecutionPlanNode>>,

    /// The stage id of the source of `BatchExchange`.
    /// Used to find `ExchangeSource` from scheduler when creating `PlanNode`.
    ///
    /// `None` when this node is not `BatchExchange`.
    pub source_stage_id: Option<StageId>,
}

impl Serialize for ExecutionPlanNode {
    /// Serializes the node as a struct with five fields: `plan_node_id`, `plan_node_type`,
    /// `schema`, `children` and `source_stage_id`.
    ///
    /// The protobuf body (`node`) is intentionally not part of the serialized output.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // The struct name must describe this type; it previously said "QueryStage", which is
        // visible in formats that emit struct names.
        let mut state = serializer.serialize_struct("ExecutionPlanNode", 5)?;
        state.serialize_field("plan_node_id", &self.plan_node_id)?;
        state.serialize_field("plan_node_type", &self.plan_node_type)?;
        state.serialize_field("schema", &self.schema)?;
        state.serialize_field("children", &self.children)?;
        state.serialize_field("source_stage_id", &self.source_stage_id)?;
        state.end()
    }
}

impl TryFrom<PlanRef> for ExecutionPlanNode {
    type Error = SchedulerError;

    /// Converts a single plan node into an execution node.
    ///
    /// The result has no children and no `source_stage_id`. Fails if the plan node cannot be
    /// turned into a batch protobuf body.
    fn try_from(plan_node: PlanRef) -> Result<Self, Self::Error> {
        let plan_node_id = plan_node.plan_base().id();
        let plan_node_type = plan_node.node_type();
        let node = plan_node.try_to_batch_prost_body()?;
        let schema = plan_node.schema().to_prost();

        Ok(Self {
            plan_node_id,
            plan_node_type,
            node,
            schema,
            children: Vec::new(),
            source_stage_id: None,
        })
    }
}

impl ExecutionPlanNode {
    /// Returns the type of the plan node this execution node was built from.
    pub fn node_type(&self) -> PlanNodeType {
        self.plan_node_type
    }
}

/// `BatchPlanFragmenter` splits a query plan into fragments.
pub struct BatchPlanFragmenter {
    /// Id of the query being fragmented; generated via `Default` in `new`.
    query_id: QueryId,
    /// Stage id counter; starts at 0.
    next_stage_id: StageId,
    /// Supplies the schedule unit count and is passed to `Distribution::to_prost`.
    worker_node_manager: WorkerNodeSelector,
    /// Passed to `Distribution::to_prost`.
    catalog_reader: CatalogReader,

    /// Effective parallelism computed in `new`. Can be 0 if no serving worker is available.
    batch_parallelism: usize,
    /// Time zone name supplied by the caller of `new`.
    timezone: String,

    /// Builder for the stage graph. Taken (left as `None`) once `stage_graph` is built.
    stage_graph_builder: Option<StageGraphBuilder>,
    /// The built stage graph; `Some` after `split_into_stage` succeeds.
    stage_graph: Option<StageGraph>,
}

impl Default for QueryId {
    /// Creates a query id from a newly generated v4 UUID.
    fn default() -> Self {
        let id = Uuid::new_v4().to_string();
        Self { id }
    }
}

impl BatchPlanFragmenter {
    pub fn new(
        worker_node_manager: WorkerNodeSelector,
        catalog_reader: CatalogReader,
        batch_parallelism: Option<NonZeroU64>,
        timezone: String,
        batch_node: PlanRef,
    ) -> SchedulerResult<Self> {
        // if batch_parallelism is None, it means no limit, we will use the available nodes count as
        // parallelism.
        // if batch_parallelism is Some(num), we will use the min(num, the available
        // nodes count) as parallelism.
        let batch_parallelism = if let Some(num) = batch_parallelism {
            // can be 0 if no available serving worker
            min(
                num.get() as usize,
                worker_node_manager.schedule_unit_count(),
            )
        } else {
            // can be 0 if no available serving worker
            worker_node_manager.schedule_unit_count()
        };

        let mut plan_fragmenter = Self {
            query_id: Default::default(),
            next_stage_id: 0,
            worker_node_manager,
            catalog_reader,
            batch_parallelism,
            timezone,
            stage_graph_builder: Some(StageGraphBuilder::new(batch_parallelism)),
            stage_graph: None,
        };
        plan_fragmenter.split_into_stage(batch_node)?;
        Ok(plan_fragmenter)
    }

    /// Split the plan node into each stages, based on exchange node.
    fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> {
        let root_stage = self.new_stage(
            batch_node,
            Some(Distribution::Single.to_prost(
                1,
                &self.catalog_reader,
                &self.worker_node_manager,
            )?),
        )?;
        self.stage_graph = Some(
            self.stage_graph_builder
                .take()
                .unwrap()
                .build(root_stage.id),
        );
        Ok(())
    }
}

/// The fragmented query generated by [`BatchPlanFragmenter`].
#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
pub struct Query {
    /// Query id should always be unique.
    pub query_id: QueryId,
    /// The stages of the query and the parent/child edges between them.
    pub stage_graph: StageGraph,
}

impl Query {
    /// Returns the ids of all stages that have no child stage.
    ///
    /// Panics if a stage has no entry in the child-edge map.
    pub fn leaf_stages(&self) -> Vec<StageId> {
        self.stage_graph
            .stages
            .keys()
            .filter(|&stage_id| {
                self.stage_graph
                    .get_child_stages_unchecked(stage_id)
                    .is_empty()
            })
            .copied()
            .collect()
    }

    /// Returns the parent stages of `stage_id`.
    ///
    /// Panics if `stage_id` has no entry in the parent-edge map.
    pub fn get_parents(&self, stage_id: &StageId) -> &HashSet<StageId> {
        let parents = self.stage_graph.parent_edges.get(stage_id);
        parents.unwrap()
    }

    /// Returns the id of the root stage.
    pub fn root_stage_id(&self) -> StageId {
        self.stage_graph.root_stage_id
    }

    /// Returns the id of this query.
    pub fn query_id(&self) -> &QueryId {
        &self.query_id
    }

    /// Returns the ids of all stages that contain a table scan.
    pub fn stages_with_table_scan(&self) -> HashSet<StageId> {
        self.stage_graph
            .stages
            .iter()
            .filter(|(_, stage)| stage.has_table_scan())
            .map(|(stage_id, _)| *stage_id)
            .collect()
    }

    /// Returns `true` if any stage of the query contains a lookup join.
    pub fn has_lookup_join_stage(&self) -> bool {
        self.stage_graph
            .stages
            .values()
            .any(|stage| stage.has_lookup_join())
    }
}

/// Connector-specific parameters used when listing the splits of a source.
#[derive(Debug, Clone)]
pub enum SourceFetchParameters {
    /// Scan type and predicate for an iceberg source.
    IcebergSpecificInfo(IcebergSpecificInfo),
    /// Optional lower/upper bounds passed to the kafka enumerator when listing splits.
    KafkaTimebound {
        lower: Option<i64>,
        upper: Option<i64>,
    },
    /// No extra parameters. Used with the S3, GCS and Azblob file sources.
    Empty,
}

/// Everything needed to list the splits of a source for a batch query.
#[derive(Debug, Clone)]
pub struct SourceFetchInfo {
    /// Schema handed to the iceberg enumerator when listing splits.
    pub schema: Schema,
    /// These are user-configured connector properties.
    /// e.g. host, username, etc...
    pub connector: ConnectorProperties,
    /// These parameters are internally derived by the plan node.
    /// e.g. predicate pushdown for iceberg, timebound for kafka.
    pub fetch_parameters: SourceFetchParameters,
    /// Optional `AS OF` clause; translated into iceberg time travel info for iceberg sources.
    pub as_of: Option<AsOf>,
}

/// Iceberg-only parameters passed to the iceberg split enumerator.
#[derive(Debug, Clone)]
pub struct IcebergSpecificInfo {
    /// The kind of iceberg scan to perform.
    pub iceberg_scan_type: IcebergScanType,
    /// Predicate handed to the enumerator when listing splits.
    pub predicate: IcebergPredicate,
}

/// Source scan information of a stage, before or after its splits have been listed.
#[derive(Clone, Debug)]
pub enum SourceScanInfo {
    /// Splits are not listed yet; holds the information needed to list them.
    Incomplete(SourceFetchInfo),
    /// The listed splits.
    Complete(Vec<SplitImpl>),
}

impl SourceScanInfo {
    pub fn new(fetch_info: SourceFetchInfo) -> Self {
        Self::Incomplete(fetch_info)
    }

    pub async fn complete(
        self,
        batch_parallelism: usize,
        timezone: String,
    ) -> SchedulerResult<Self> {
        let fetch_info = match self {
            SourceScanInfo::Incomplete(fetch_info) => fetch_info,
            SourceScanInfo::Complete(_) => {
                unreachable!("Never call complete when SourceScanInfo is already complete")
            }
        };
        match (fetch_info.connector, fetch_info.fetch_parameters) {
            (
                ConnectorProperties::Kafka(prop),
                SourceFetchParameters::KafkaTimebound { lower, upper },
            ) => {
                let mut kafka_enumerator =
                    KafkaSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;
                let split_info = kafka_enumerator
                    .list_splits_batch(lower, upper)
                    .await?
                    .into_iter()
                    .map(SplitImpl::Kafka)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            (ConnectorProperties::OpendalS3(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalS3> =
                    OpendalEnumerator::new_s3_source(prop.s3_properties, prop.assume_role)?;
                let stream = build_opendal_fs_list_for_batch(lister);

                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res
                    .into_iter()
                    .map(SplitImpl::OpendalS3)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Gcs(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalGcs> =
                    OpendalEnumerator::new_gcs_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Gcs).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (ConnectorProperties::Azblob(prop), SourceFetchParameters::Empty) => {
                let lister: OpendalEnumerator<OpendalAzblob> =
                    OpendalEnumerator::new_azblob_source(*prop)?;
                let stream = build_opendal_fs_list_for_batch(lister);
                let batch_res: Vec<_> = stream.try_collect().await?;
                let res = batch_res.into_iter().map(SplitImpl::Azblob).collect_vec();

                Ok(SourceScanInfo::Complete(res))
            }
            (
                ConnectorProperties::Iceberg(prop),
                SourceFetchParameters::IcebergSpecificInfo(iceberg_specific_info),
            ) => {
                let iceberg_enumerator =
                    IcebergSplitEnumerator::new(*prop, SourceEnumeratorContext::dummy().into())
                        .await?;

                let time_travel_info = match fetch_info.as_of {
                    Some(AsOf::VersionNum(v)) => Some(IcebergTimeTravelInfo::Version(v)),
                    Some(AsOf::TimestampNum(ts)) => {
                        Some(IcebergTimeTravelInfo::TimestampMs(ts * 1000))
                    }
                    Some(AsOf::VersionString(_)) => {
                        bail!("Unsupported version string in iceberg time travel")
                    }
                    Some(AsOf::TimestampString(ts)) => {
                        let date_time = speedate::DateTime::parse_str_rfc3339(&ts)
                            .map_err(|_e| anyhow!("fail to parse timestamp"))?;
                        let timestamp = if date_time.time.tz_offset.is_none() {
                            // If the input does not specify a time zone, use the time zone set by the "SET TIME ZONE" command.
                            let tz =
                                Timestamptz::lookup_time_zone(&timezone).map_err(|e| anyhow!(e))?;
                            match tz.with_ymd_and_hms(
                                date_time.date.year.into(),
                                date_time.date.month.into(),
                                date_time.date.day.into(),
                                date_time.time.hour.into(),
                                date_time.time.minute.into(),
                                date_time.time.second.into(),
                            ) {
                                MappedLocalTime::Single(d) => Ok(d.timestamp()),
                                MappedLocalTime::Ambiguous(_, _) | MappedLocalTime::None => {
                                    Err(anyhow!(format!(
                                        "failed to parse the timestamp {ts} with the specified time zone {tz}"
                                    )))
                                }
                            }?
                        } else {
                            date_time.timestamp_tz()
                        };

                        Some(IcebergTimeTravelInfo::TimestampMs(
                            timestamp * 1000 + date_time.time.microsecond as i64 / 1000,
                        ))
                    }
                    Some(AsOf::ProcessTime) | Some(AsOf::ProcessTimeWithInterval(_)) => {
                        unreachable!()
                    }
                    None => None,
                };

                let split_info = iceberg_enumerator
                    .list_splits_batch(
                        fetch_info.schema,
                        time_travel_info,
                        batch_parallelism,
                        iceberg_specific_info.iceberg_scan_type,
                        iceberg_specific_info.predicate,
                    )
                    .await?
                    .into_iter()
                    .map(SplitImpl::Iceberg)
                    .collect_vec();

                Ok(SourceScanInfo::Complete(split_info))
            }
            _ => Err(SchedulerError::Internal(anyhow!(
                "Unsupported to query directly from this source"
            ))),
        }
    }

    pub fn split_info(&self) -> SchedulerResult<&Vec<SplitImpl>> {
        match self {
            Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!(
                "Should not get split info from incomplete source scan info"
            ))),
            Self::Complete(split_info) => Ok(split_info),
        }
    }
}

/// Information about the table scanned by a stage: its name and the partitions to read.
#[derive(Clone, Debug)]
pub struct TableScanInfo {
    /// The name of the table to scan.
    name: String,

    /// Indicates the table partitions to be read by scan tasks. Unnecessary partitions are already
    /// pruned.
    ///
    /// For singleton table, this field is still `Some` and only contains a single partition with
    /// full vnode bitmap, since we need to know where to schedule the singleton scan task.
    ///
    /// `None` iff the table is a system table.
    partitions: Option<HashMap<WorkerSlotId, TablePartitionInfo>>,
}

impl TableScanInfo {
    /// Builds the scan info of a normal table, for which `partitions` is always `Some`.
    pub fn new(name: String, partitions: HashMap<WorkerSlotId, TablePartitionInfo>) -> Self {
        let partitions = Some(partitions);
        Self { name, partitions }
    }

    /// Builds the scan info of a system table, which has no partition info.
    pub fn system_table(name: String) -> Self {
        Self {
            partitions: None,
            name,
        }
    }

    /// Returns the name of the table to scan.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the partitions to read, or `None` for a system table.
    pub fn partitions(&self) -> Option<&HashMap<WorkerSlotId, TablePartitionInfo>> {
        self.partitions.as_ref()
    }
}

/// What a single table partition has to read.
#[derive(Clone, Debug)]
pub struct TablePartitionInfo {
    /// Bitmap of the vnodes belonging to this partition.
    pub vnode_bitmap: Bitmap,
    /// Scan ranges (in protobuf form) for this partition.
    pub scan_ranges: Vec<ScanRangeProto>,
}

/// Partition information for a table scan, a source scan or a file scan.
#[derive(Clone, Debug, EnumAsInner)]
pub enum PartitionInfo {
    /// Partition of a table scan.
    Table(TablePartitionInfo),
    /// Splits of a source scan.
    Source(Vec<SplitImpl>),
    /// Files of a file scan, as strings.
    File(Vec<String>),
}

/// Information about the files read by a file scan stage.
#[derive(Clone, Debug)]
pub struct FileScanInfo {
    /// Locations of the files to read. Their count bounds the stage's parallelism in
    /// `StageGraph::complete_stage`.
    pub file_location: Vec<String>,
}

/// Fragment part of `Query`.
#[derive(Clone)]
pub struct QueryStage {
    /// Id of the query this stage belongs to.
    pub query_id: QueryId,
    /// Id of this stage.
    pub id: StageId,
    /// Root of the execution plan of this stage.
    pub root: Arc<ExecutionPlanNode>,
    /// Exchange information of this stage, if any.
    pub exchange_info: Option<ExchangeInfo>,
    /// Parallelism of the stage. `None` until `StageGraph::complete_stage` determines it from
    /// source or file scan information.
    pub parallelism: Option<u32>,
    /// Indicates whether this stage contains a table scan node and the table's information if so.
    pub table_scan_info: Option<TableScanInfo>,
    /// Source scan information, if any. May still be `Incomplete` before the stage is completed.
    pub source_info: Option<SourceScanInfo>,
    /// File scan information, if any.
    pub file_scan_info: Option<FileScanInfo>,
    /// Whether this stage contains a lookup join executor.
    pub has_lookup_join: bool,
    /// Id of the table targeted by DML, if any.
    pub dml_table_id: Option<TableId>,
    /// Id of the session associated with this stage.
    pub session_id: SessionId,
    /// Value of the distributed-DML switch captured for this stage.
    pub batch_enable_distributed_dml: bool,

    /// Used to generate exchange information when complete source scan information.
    children_exchange_distribution: Option<HashMap<StageId, Distribution>>,
}

impl QueryStage {
    /// If true, this stage contains table scan executor that creates
    /// Hummock iterators to read data from table. The iterator is initialized during
    /// the executor building process on the batch execution engine.
    pub fn has_table_scan(&self) -> bool {
        self.table_scan_info.is_some()
    }

    /// If true, this stage contains lookup join executor.
    /// We need to delay epoch unpin until the end of the query.
    pub fn has_lookup_join(&self) -> bool {
        self.has_lookup_join
    }

    /// Returns a copy of this stage with the given exchange info and parallelism.
    ///
    /// When `exchange_info` is `None` the stage is cloned unchanged, and `parallelism` is
    /// ignored as well.
    pub fn clone_with_exchange_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        parallelism: Option<u32>,
    ) -> Self {
        match exchange_info {
            None => self.clone(),
            Some(new_exchange_info) => Self {
                query_id: self.query_id.clone(),
                id: self.id,
                root: self.root.clone(),
                exchange_info: Some(new_exchange_info),
                parallelism,
                table_scan_info: self.table_scan_info.clone(),
                source_info: self.source_info.clone(),
                file_scan_info: self.file_scan_info.clone(),
                has_lookup_join: self.has_lookup_join,
                dml_table_id: self.dml_table_id,
                session_id: self.session_id,
                batch_enable_distributed_dml: self.batch_enable_distributed_dml,
                children_exchange_distribution: self.children_exchange_distribution.clone(),
            },
        }
    }

    /// Returns a copy of this stage with complete source info and a fixed parallelism.
    ///
    /// `exchange_info` replaces the stage's exchange info when it is `Some`; otherwise the
    /// existing one is kept. The copy no longer carries the children exchange distributions.
    ///
    /// Panics if `source_info` is not `SourceScanInfo::Complete`.
    pub fn clone_with_exchange_info_and_complete_source_info(
        &self,
        exchange_info: Option<ExchangeInfo>,
        source_info: SourceScanInfo,
        task_parallelism: u32,
    ) -> Self {
        assert!(matches!(source_info, SourceScanInfo::Complete(_)));
        let exchange_info = exchange_info.or_else(|| self.exchange_info.clone());

        Self {
            query_id: self.query_id.clone(),
            id: self.id,
            root: self.root.clone(),
            exchange_info,
            parallelism: Some(task_parallelism),
            table_scan_info: self.table_scan_info.clone(),
            source_info: Some(source_info),
            file_scan_info: self.file_scan_info.clone(),
            has_lookup_join: self.has_lookup_join,
            dml_table_id: self.dml_table_id,
            session_id: self.session_id,
            batch_enable_distributed_dml: self.batch_enable_distributed_dml,
            children_exchange_distribution: None,
        }
    }
}

impl Debug for QueryStage {
    /// Prints a compact summary: id, parallelism, exchange info and whether the stage has a
    /// table scan. The remaining fields are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let has_table_scan = self.has_table_scan();
        let mut summary = f.debug_struct("QueryStage");
        summary.field("id", &self.id);
        summary.field("parallelism", &self.parallelism);
        summary.field("exchange_info", &self.exchange_info);
        summary.field("has_table_scan", &has_table_scan);
        summary.finish()
    }
}

impl Serialize for QueryStage {
    /// Serializes only three fields of the stage: `root`, `parallelism` and `exchange_info`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        const FIELD_COUNT: usize = 3;
        let mut stage_struct = serializer.serialize_struct("QueryStage", FIELD_COUNT)?;
        stage_struct.serialize_field("root", &self.root)?;
        stage_struct.serialize_field("parallelism", &self.parallelism)?;
        stage_struct.serialize_field("exchange_info", &self.exchange_info)?;
        stage_struct.end()
    }
}

/// Shared, reference-counted handle to a [`QueryStage`].
pub type QueryStageRef = Arc<QueryStage>;

/// Collects the pieces of a [`QueryStage`] while it is being built; `finish` turns it into a
/// stage and registers it in the stage graph builder.
struct QueryStageBuilder {
    query_id: QueryId,
    id: StageId,
    /// Root of the stage's plan. Must be set before `finish`, which unwraps it.
    root: Option<Arc<ExecutionPlanNode>>,
    parallelism: Option<u32>,
    exchange_info: Option<ExchangeInfo>,

    /// Child stages; `finish` links each of them to this stage in the stage graph.
    children_stages: Vec<QueryStageRef>,
    /// See also [`QueryStage::table_scan_info`].
    table_scan_info: Option<TableScanInfo>,
    source_info: Option<SourceScanInfo>,
    /// Becomes `QueryStage::file_scan_info`.
    file_scan_file: Option<FileScanInfo>,
    has_lookup_join: bool,
    dml_table_id: Option<TableId>,
    session_id: SessionId,
    batch_enable_distributed_dml: bool,

    /// Exchange distribution per child stage. Only kept on the finished stage when
    /// `parallelism` is `None`.
    children_exchange_distribution: HashMap<StageId, Distribution>,
}

impl QueryStageBuilder {
    /// Creates a builder with the given stage properties.
    ///
    /// The root node, the child stages and the children exchange distributions start out
    /// empty.
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: StageId,
        query_id: QueryId,
        parallelism: Option<u32>,
        exchange_info: Option<ExchangeInfo>,
        table_scan_info: Option<TableScanInfo>,
        source_info: Option<SourceScanInfo>,
        file_scan_file: Option<FileScanInfo>,
        has_lookup_join: bool,
        dml_table_id: Option<TableId>,
        session_id: SessionId,
        batch_enable_distributed_dml: bool,
    ) -> Self {
        Self {
            id,
            query_id,
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_file,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            root: None,
            children_stages: Vec::new(),
            children_exchange_distribution: HashMap::new(),
        }
    }

    /// Builds the stage, adds it to `stage_graph_builder` and links it to its child stages.
    ///
    /// The children exchange distributions are kept only when the parallelism is still
    /// unknown (`None`).
    ///
    /// Panics if the root node has not been set.
    fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> QueryStageRef {
        let Self {
            query_id,
            id,
            root,
            parallelism,
            exchange_info,
            children_stages,
            table_scan_info,
            source_info,
            file_scan_file,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution,
        } = self;

        let children_exchange_distribution = parallelism
            .is_none()
            .then_some(children_exchange_distribution);

        let stage = Arc::new(QueryStage {
            query_id,
            id,
            root: root.unwrap(),
            exchange_info,
            parallelism,
            table_scan_info,
            source_info,
            file_scan_info: file_scan_file,
            has_lookup_join,
            dml_table_id,
            session_id,
            batch_enable_distributed_dml,
            children_exchange_distribution,
        });

        // The node has to exist in the graph before edges to its children can be added.
        stage_graph_builder.add_node(Arc::clone(&stage));
        for child_stage in &children_stages {
            stage_graph_builder.link_to_child(id, child_stage.id);
        }
        stage
    }
}

/// Maintains how the stages are connected.
#[derive(Debug, Serialize)]
#[cfg_attr(test, derive(Clone))]
pub struct StageGraph {
    /// Id of the root stage.
    pub root_stage_id: StageId,
    /// All stages, keyed by stage id.
    pub stages: HashMap<StageId, QueryStageRef>,
    /// Traverse from top to down. Used in split plan into stages.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// Traverse from down to top. Used in schedule each stage.
    parent_edges: HashMap<StageId, HashSet<StageId>>,

    /// Batch parallelism the graph was built with; used to bound the parallelism of source
    /// and file scan stages when completing them.
    batch_parallelism: usize,
}

impl StageGraph {
    /// Returns the child stages of `stage_id`.
    ///
    /// Panics if `stage_id` has no entry in the child-edge map.
    pub fn get_child_stages_unchecked(&self, stage_id: &StageId) -> &HashSet<StageId> {
        self.child_edges.get(stage_id).unwrap()
    }

    /// Returns the child stages of `stage_id`, or `None` if the stage is unknown.
    pub fn get_child_stages(&self, stage_id: &StageId) -> Option<&HashSet<StageId>> {
        self.child_edges.get(stage_id)
    }

    /// Returns stage ids in topology order, s.t. child stage always appears before its parent.
    ///
    /// Stages are visited from the root with an explicit stack, each at most once, and the
    /// visit order is then reversed. Panics if a visited stage has no child-edge entry.
    pub fn stage_ids_by_topo_order(&self) -> impl Iterator<Item = StageId> {
        let mut stack = Vec::with_capacity(self.stages.len());
        stack.push(self.root_stage_id);
        let mut ret = Vec::with_capacity(self.stages.len());
        let mut existing = HashSet::with_capacity(self.stages.len());

        while let Some(s) = stack.pop() {
            if !existing.contains(&s) {
                ret.push(s);
                existing.insert(s);
                stack.extend(&self.child_edges[&s]);
            }
        }

        ret.into_iter().rev()
    }

    /// Completes every stage reachable from the root and returns a new graph made of the
    /// completed stages, with the same edges and batch parallelism.
    ///
    /// Panics if the root stage is missing from `stages`.
    async fn complete(
        self,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<StageGraph> {
        let mut complete_stages = HashMap::new();
        self.complete_stage(
            self.stages.get(&self.root_stage_id).unwrap().clone(),
            None,
            &mut complete_stages,
            catalog_reader,
            worker_node_manager,
            timezone,
        )
        .await?;
        Ok(StageGraph {
            root_stage_id: self.root_stage_id,
            stages: complete_stages,
            child_edges: self.child_edges,
            parent_edges: self.parent_edges,
            batch_parallelism: self.batch_parallelism,
        })
    }

    /// Completes `stage`, inserts it into `complete_stages`, then recurses into its children.
    ///
    /// Three cases are distinguished:
    /// - the stage already has a parallelism: it is copied with the given `exchange_info`;
    /// - the stage has incomplete source info: the splits are listed and the parallelism is
    ///   derived from their number;
    /// - otherwise the stage must be a file scan stage and the parallelism is derived from
    ///   the number of files.
    ///
    /// Only in the second case do the children receive freshly generated exchange info, built
    /// from the stage's children exchange distributions and the new parallelism.
    #[async_recursion]
    async fn complete_stage(
        &self,
        stage: QueryStageRef,
        exchange_info: Option<ExchangeInfo>,
        complete_stages: &mut HashMap<StageId, QueryStageRef>,
        catalog_reader: &CatalogReader,
        worker_node_manager: &WorkerNodeSelector,
        timezone: String,
    ) -> SchedulerResult<()> {
        let parallelism = if stage.parallelism.is_some() {
            // If the stage has parallelism, it means it's a complete stage.
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, stage.parallelism)),
            );
            None
        } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) {
            let complete_source_info = stage
                .source_info
                .as_ref()
                .unwrap()
                .clone()
                .complete(self.batch_parallelism, timezone.to_owned())
                .await?;

            // For batch reading file source, the number of files involved is typically large.
            // In order to avoid generating a task for each file, the parallelism of tasks is limited here.
            // `task_parallelism` must not exceed the number of files to read, so we first take
            // the minimum of the number of files and (self.batch_parallelism / 2). The minimum
            // `task_parallelism` is 1: if that minimum is 0 (no files, or
            // `batch_parallelism` < 2), we set task_parallelism to 1.

            let task_parallelism = match &stage.source_info {
                Some(SourceScanInfo::Incomplete(source_fetch_info)) => {
                    match source_fetch_info.connector {
                        ConnectorProperties::Gcs(_)
                        | ConnectorProperties::OpendalS3(_)
                        | ConnectorProperties::Azblob(_) => (min(
                            complete_source_info.split_info().unwrap().len() as u32,
                            (self.batch_parallelism / 2) as u32,
                        ))
                        .max(1),
                        _ => complete_source_info.split_info().unwrap().len() as u32,
                    }
                }
                _ => unreachable!(),
            };
            // For file source batch read, all the files to be read are divided into several parts to prevent the task from taking up too many resources.
            // todo(wcy-fdu): Currently it will be divided into half of batch_parallelism groups, and this will be changed to configurable later.
            let complete_stage = Arc::new(stage.clone_with_exchange_info_and_complete_source_info(
                exchange_info,
                complete_source_info,
                task_parallelism,
            ));
            let parallelism = complete_stage.parallelism;
            complete_stages.insert(stage.id, complete_stage);
            parallelism
        } else {
            assert!(stage.file_scan_info.is_some());
            // Same bound as for file sources above. Clamp to at least 1 so that a
            // `batch_parallelism` below 2 (or an empty file list) does not produce a stage
            // with zero tasks.
            let parallelism = min(
                self.batch_parallelism / 2,
                stage.file_scan_info.as_ref().unwrap().file_location.len(),
            )
            .max(1);
            complete_stages.insert(
                stage.id,
                Arc::new(stage.clone_with_exchange_info(exchange_info, Some(parallelism as u32))),
            );
            None
        };

        for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) {
            // Children only get new exchange info when this stage's parallelism was just
            // determined from its source info.
            let exchange_info = if let Some(parallelism) = parallelism {
                let exchange_distribution = stage
                    .children_exchange_distribution
                    .as_ref()
                    .unwrap()
                    .get(child_stage_id)
                    .expect("Exchange distribution is not consistent with the stage graph");
                Some(exchange_distribution.to_prost(
                    parallelism,
                    catalog_reader,
                    worker_node_manager,
                )?)
            } else {
                None
            };
            self.complete_stage(
                self.stages.get(child_stage_id).unwrap().clone(),
                exchange_info,
                complete_stages,
                catalog_reader,
                worker_node_manager,
                timezone.to_owned(),
            )
            .await?;
        }

        Ok(())
    }

    /// Converts the `StageGraph` into a `petgraph::graph::Graph<String, String>`.
    ///
    /// Nodes are added in ascending stage id order and labelled with the stage's `Debug`
    /// output; edges point from parent to child and carry an empty label.
    pub fn to_petgraph(&self) -> Graph<String, String, Directed> {
        let mut graph = Graph::<String, String, Directed>::new();

        let mut node_indices = HashMap::new();

        // Add all stages as nodes
        for (&stage_id, stage_ref) in self.stages.iter().sorted_by_key(|(id, _)| **id) {
            let node_label = format!("Stage {}: {:?}", stage_id, stage_ref);
            let node_index = graph.add_node(node_label);
            node_indices.insert(stage_id, node_index);
        }

        // Add edges between stages based on child_edges
        for (&parent_id, children) in &self.child_edges {
            if let Some(&parent_index) = node_indices.get(&parent_id) {
                for &child_id in children {
                    if let Some(&child_index) = node_indices.get(&child_id) {
                        // Add an edge from parent to child
                        graph.add_edge(parent_index, child_index, "".to_owned());
                    }
                }
            }
        }

        graph
    }
}

/// Mutable accumulator for stages and their parent/child links; `build` converts it into a
/// `StageGraph`.
struct StageGraphBuilder {
    /// Stages registered through `add_node`, keyed by stage id.
    stages: HashMap<StageId, QueryStageRef>,
    /// For each registered stage, the ids of the stages linked as its children.
    child_edges: HashMap<StageId, HashSet<StageId>>,
    /// For each registered stage, the ids of the stages linked as its parents.
    parent_edges: HashMap<StageId, HashSet<StageId>>,
    /// Copied unchanged into `StageGraph::batch_parallelism` by `build`.
    batch_parallelism: usize,
}

impl StageGraphBuilder {
    /// Creates a builder that holds no stages and no edges yet.
    pub fn new(batch_parallelism: usize) -> Self {
        Self {
            batch_parallelism,
            stages: HashMap::default(),
            child_edges: HashMap::default(),
            parent_edges: HashMap::default(),
        }
    }

    /// Consumes the builder and produces the final `StageGraph` rooted at `root_stage_id`.
    pub fn build(self, root_stage_id: StageId) -> StageGraph {
        let Self {
            stages,
            child_edges,
            parent_edges,
            batch_parallelism,
        } = self;

        StageGraph {
            root_stage_id,
            stages,
            child_edges,
            parent_edges,
            batch_parallelism,
        }
    }

    /// Records that `child_id` is a child of `parent_id`, updating both the parent -> child and
    /// the child -> parent mapping.
    ///
    /// # Panics
    ///
    /// Panics if either stage has not been registered with `add_node` beforehand.
    pub fn link_to_child(&mut self, parent_id: StageId, child_id: StageId) {
        let children_of_parent = self.child_edges.get_mut(&parent_id).unwrap();
        children_of_parent.insert(child_id);

        let parents_of_child = self.parent_edges.get_mut(&child_id).unwrap();
        parents_of_child.insert(parent_id);
    }

    /// Registers `stage` in the graph.
    pub fn add_node(&mut self, stage: QueryStageRef) {
        let stage_id = stage.id;
        // Seed empty edge sets right away so that stages without children or without parents
        // still have an entry in both edge maps.
        self.child_edges.insert(stage_id, HashSet::new());
        self.parent_edges.insert(stage_id, HashSet::new());
        self.stages.insert(stage_id, stage);
    }
}

impl BatchPlanFragmenter {
    /// Turns the fragmenter into a `Query` whose stage graph has been completed.
    ///
    /// After the split step, stages in `stage_graph` may still carry incomplete source
    /// information. That information is filled in here via `StageGraph::complete`. The two
    /// steps are kept apart because fetching source information is an async operation and
    /// therefore cannot happen during the split step.
    ///
    /// # Panics
    ///
    /// Panics if `stage_graph` is `None`.
    pub async fn generate_complete_query(self) -> SchedulerResult<Query> {
        let incomplete_graph = self.stage_graph.unwrap();
        let timezone = self.timezone.to_owned();

        let stage_graph = incomplete_graph
            .complete(&self.catalog_reader, &self.worker_node_manager, timezone)
            .await?;

        Ok(Query {
            query_id: self.query_id,
            stage_graph,
        })
    }

    /// Creates the stage whose plan tree is rooted at `root` and allocates the next stage id
    /// for it.
    ///
    /// Scan information is collected first (table scan, otherwise source, otherwise file scan).
    /// The stage parallelism is then derived from the distribution of `root` together with that
    /// scan information. Finally the plan is walked with `visit_node`, which creates the child
    /// stages at exchange nodes.
    ///
    /// # Errors
    ///
    /// * `BatchError::EmptyWorkerNodes` when the derived parallelism is 0 and the stage has
    ///   neither a source nor a file scan.
    /// * An internal error when a stage with single distribution contains a source operator.
    fn new_stage(
        &mut self,
        root: PlanRef,
        exchange_info: Option<ExchangeInfo>,
    ) -> SchedulerResult<QueryStageRef> {
        let stage_id = self.next_stage_id;
        self.next_stage_id += 1;

        // For current implementation, we can guarantee that each stage has only one table
        // scan(except System table) or one source, so the lookups below are mutually exclusive.
        let mut table_scan_info = self.collect_stage_table_scan(root.clone())?;
        let source_info = if table_scan_info.is_some() {
            None
        } else {
            Self::collect_stage_source(root.clone())?
        };
        let file_scan_info = if table_scan_info.is_some() || source_info.is_some() {
            None
        } else {
            Self::collect_stage_file_scan(root.clone())?
        };

        let mut has_lookup_join = false;
        let raw_parallelism = if matches!(root.distribution(), Distribution::Single) {
            match &mut table_scan_info {
                Some(scan_info) => {
                    if let Some(partitions) = &mut scan_info.partitions {
                        if partitions.len() != 1 {
                            // This is rare case, but it's possible on the internal state of the
                            // Source operator.
                            tracing::warn!(
                                "The stage has single distribution, but contains a scan of table `{}` with {} partitions. A single random worker will be assigned",
                                scan_info.name,
                                partitions.len()
                            );

                            // Keep one arbitrary partition and let it cover every vnode.
                            *partitions = partitions
                                .drain()
                                .take(1)
                                .update(|(_, partition)| {
                                    partition.vnode_bitmap =
                                        Bitmap::ones(partition.vnode_bitmap.len());
                                })
                                .collect();
                        }
                    } else {
                        // System table
                    }
                }
                None if source_info.is_some() => {
                    return Err(SchedulerError::Internal(anyhow!(
                        "The stage has single distribution, but contains a source operator"
                    )));
                }
                None => {}
            }
            1
        } else if let Some(scan_info) = &table_scan_info {
            // One task per partition; a scan without partitions runs as a single task.
            scan_info
                .partitions
                .as_ref()
                .map_or(1, |partitions| partitions.len())
        } else if let Some(lookup_join_parallelism) =
            self.collect_stage_lookup_join_parallelism(root.clone())?
        {
            has_lookup_join = true;
            lookup_join_parallelism
        } else if source_info.is_some() {
            0
        } else if file_scan_info.is_some() {
            1
        } else {
            self.batch_parallelism
        };

        // A parallelism of 0 is only tolerated for source / file-scan stages, where it is
        // recorded as `None`; for every other stage it is an error.
        let parallelism = match raw_parallelism {
            0 if source_info.is_none() && file_scan_info.is_none() => {
                return Err(BatchError::EmptyWorkerNodes.into());
            }
            0 => None,
            n => Some(n as u32),
        };

        let dml_table_id = Self::collect_dml_table_id(&root);
        let mut builder = QueryStageBuilder::new(
            stage_id,
            self.query_id.clone(),
            parallelism,
            exchange_info,
            table_scan_info,
            source_info,
            file_scan_info,
            has_lookup_join,
            dml_table_id,
            root.ctx().session_ctx().session_id(),
            root.ctx()
                .session_ctx()
                .config()
                .batch_enable_distributed_dml(),
        );

        self.visit_node(root, &mut builder, None)?;

        let stage_graph_builder = self.stage_graph_builder.as_mut().unwrap();
        Ok(builder.finish(stage_graph_builder))
    }

    /// Recursively converts `node` and its inputs into `ExecutionPlanNode`s of the stage being
    /// built by `builder`.
    ///
    /// Exchange nodes are delegated to `visit_exchange`. Any other node is converted, its inputs
    /// are visited with the new node as their parent, and the result is attached either to
    /// `parent_exec_node` or, when there is no parent, installed as the root of `builder`.
    fn visit_node(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        if node.node_type() == PlanNodeType::BatchExchange {
            return self.visit_exchange(node.clone(), builder, parent_exec_node);
        }

        let mut exec_node = ExecutionPlanNode::try_from(node.clone())?;
        for input in node.inputs() {
            self.visit_node(input, builder, Some(&mut exec_node))?;
        }

        // The node is frozen into an `Arc` only after all of its children have been attached.
        let exec_node = Arc::new(exec_node);
        match parent_exec_node {
            Some(parent) => parent.children.push(exec_node),
            None => builder.root = Some(exec_node),
        }
        Ok(())
    }

    /// Handles an exchange node: its input becomes a new child stage and the exchange itself is
    /// kept in the current stage as a node pointing at that child stage.
    ///
    /// When the current stage already has a parallelism, the exchange info handed to the child
    /// stage is computed right away. Otherwise the child is created without exchange info and
    /// the distribution of the exchange is stored in `builder.children_exchange_distribution`,
    /// keyed by the child stage id.
    fn visit_exchange(
        &mut self,
        node: PlanRef,
        builder: &mut QueryStageBuilder,
        parent_exec_node: Option<&mut ExecutionPlanNode>,
    ) -> SchedulerResult<()> {
        let mut exchange_exec_node = ExecutionPlanNode::try_from(node.clone())?;

        let child_exchange_info = builder
            .parallelism
            .map(|parallelism| {
                node.distribution().to_prost(
                    parallelism,
                    &self.catalog_reader,
                    &self.worker_node_manager,
                )
            })
            .transpose()?;

        let child_stage = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?;
        let child_stage_id = child_stage.id;
        exchange_exec_node.source_stage_id = Some(child_stage_id);

        if builder.parallelism.is_none() {
            builder
                .children_exchange_distribution
                .insert(child_stage_id, node.distribution().clone());
        }

        let exchange_exec_node = Arc::new(exchange_exec_node);
        match parent_exec_node {
            Some(parent) => parent.children.push(exchange_exec_node),
            None => builder.root = Some(exchange_exec_node),
        }

        builder.children_stages.push(child_stage);
        Ok(())
    }

    /// Looks for a source node (Kafka scan, Iceberg scan or generic batch source) in the current
    /// stage and, if one with a source catalog is found, returns the `SourceScanInfo` describing
    /// what has to be fetched from it.
    ///
    /// The search starts at `node`, descends through its inputs and stops at exchange nodes,
    /// because those mark the beginning of the next stage. A source node without a source
    /// catalog does not count as a match; the search continues with its inputs.
    ///
    /// For current implementation, we can guarantee that each stage has only one source.
    fn collect_stage_source(node: PlanRef) -> SchedulerResult<Option<SourceScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        let fetch_info = if let Some(kafka_scan) = node.as_batch_kafka_scan() {
            let kafka_scan: &BatchKafkaScan = kafka_scan;
            match kafka_scan.source_catalog() {
                Some(catalog) => {
                    let connector =
                        ConnectorProperties::extract(catalog.with_properties.clone(), false)?;
                    let timestamp_bound = kafka_scan.kafka_timestamp_range_value();
                    Some(SourceFetchInfo {
                        schema: kafka_scan.base.schema().clone(),
                        connector,
                        fetch_parameters: SourceFetchParameters::KafkaTimebound {
                            lower: timestamp_bound.0,
                            upper: timestamp_bound.1,
                        },
                        as_of: None,
                    })
                }
                None => None,
            }
        } else if let Some(iceberg_scan) = node.as_batch_iceberg_scan() {
            let iceberg_scan: &BatchIcebergScan = iceberg_scan;
            match iceberg_scan.source_catalog() {
                Some(catalog) => {
                    let connector =
                        ConnectorProperties::extract(catalog.with_properties.clone(), false)?;
                    let as_of = iceberg_scan.as_of();
                    Some(SourceFetchInfo {
                        schema: iceberg_scan.base.schema().clone(),
                        connector,
                        fetch_parameters: SourceFetchParameters::IcebergSpecificInfo(
                            IcebergSpecificInfo {
                                predicate: iceberg_scan.predicate.clone(),
                                iceberg_scan_type: iceberg_scan.iceberg_scan_type(),
                            },
                        ),
                        as_of,
                    })
                }
                None => None,
            }
        } else if let Some(batch_source) = node.as_batch_source() {
            // TODO: use specific batch operator instead of batch source.
            let batch_source: &BatchSource = batch_source;
            match batch_source.source_catalog() {
                Some(catalog) => {
                    let connector =
                        ConnectorProperties::extract(catalog.with_properties.clone(), false)?;
                    let as_of = batch_source.as_of();
                    Some(SourceFetchInfo {
                        schema: batch_source.base.schema().clone(),
                        connector,
                        fetch_parameters: SourceFetchParameters::Empty,
                        as_of,
                    })
                }
                None => None,
            }
        } else {
            None
        };

        if let Some(fetch_info) = fetch_info {
            return Ok(Some(SourceScanInfo::new(fetch_info)));
        }

        // Nothing usable at this node: the first input that yields a source (or an error) wins.
        for input in node.inputs() {
            if let Some(source_info) = Self::collect_stage_source(input)? {
                return Ok(Some(source_info));
            }
        }
        Ok(None)
    }

    /// Looks for a file scan node in the current stage and returns its file location wrapped in
    /// a `FileScanInfo`.
    ///
    /// The search is depth-first over the inputs of `node` and does not cross exchange nodes.
    /// The first match (or the first error) ends the search.
    fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult<Option<FileScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        match node.as_batch_file_scan() {
            Some(file_scan) => Ok(Some(FileScanInfo {
                file_location: file_scan.core.file_location().clone(),
            })),
            None => {
                for input in node.inputs() {
                    if let Some(file_scan_info) = Self::collect_stage_file_scan(input)? {
                        return Ok(Some(file_scan_info));
                    }
                }
                Ok(None)
            }
        }
    }

    /// Looks for a table scan node in the current stage and returns the table's scan information
    /// if there is one.
    ///
    /// System table scans yield `TableScanInfo::system_table`. Log scans and ordinary sequential
    /// scans yield a `TableScanInfo` whose partitions are derived from the vnode mapping of the
    /// table's fragment (log scans pass no scan ranges). Other nodes are searched through their
    /// inputs; exchange nodes end the search because they belong to the next stage.
    ///
    /// If there are multiple scan nodes in this stage, they must have the same distribution, but
    /// maybe different vnodes partition. We just use the same partition for all the scan nodes.
    fn collect_stage_table_scan(&self, node: PlanRef) -> SchedulerResult<Option<TableScanInfo>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        // Resolves the fragment's vnode mapping for the table and splits the scan by worker slot.
        let scan_info_for = |name, table_desc: &TableDesc, scan_range| {
            let catalog = self
                .catalog_reader
                .read_guard()
                .get_any_table_by_id(&table_desc.table_id)
                .cloned()
                .map_err(RwError::from)?;
            let mapping = self
                .worker_node_manager
                .fragment_mapping(catalog.fragment_id)?;
            let partitions = derive_partitions(scan_range, table_desc, &mapping)?;
            Ok(Some(TableScanInfo::new(name, partitions)))
        };

        if let Some(sys_scan) = node.as_batch_sys_seq_scan() {
            let table_name = sys_scan.core().table_name.to_owned();
            return Ok(Some(TableScanInfo::system_table(table_name)));
        }
        if let Some(log_scan) = node.as_batch_log_seq_scan() {
            return scan_info_for(
                log_scan.core().table_name.to_owned(),
                &log_scan.core().table_desc,
                &[],
            );
        }
        if let Some(seq_scan) = node.as_batch_seq_scan() {
            return scan_info_for(
                seq_scan.core().table_name.to_owned(),
                &seq_scan.core().table_desc,
                seq_scan.scan_ranges(),
            );
        }

        for input in node.inputs() {
            if let Some(scan_info) = self.collect_stage_table_scan(input)? {
                return Ok(Some(scan_info));
            }
        }
        Ok(None)
    }

    /// Returns the id of the table targeted by the first insert, update or delete node found in
    /// the current stage, or `None` if the stage contains no such node.
    ///
    /// The node itself is checked before its inputs, and exchange nodes are not crossed.
    fn collect_dml_table_id(node: &PlanRef) -> Option<TableId> {
        if node.node_type() == PlanNodeType::BatchExchange {
            return None;
        }

        node.as_batch_insert()
            .map(|insert| insert.core.table_id)
            .or_else(|| node.as_batch_update().map(|update| update.core.table_id))
            .or_else(|| node.as_batch_delete().map(|delete| delete.core.table_id))
            .or_else(|| {
                node.inputs()
                    .into_iter()
                    .find_map(|input| Self::collect_dml_table_id(&input))
            })
    }

    /// Looks for a lookup join in the current stage and returns the parallelism to use for it.
    ///
    /// The parallelism is the number of distinct entries in the vnode mapping of the fragment
    /// that holds the lookup join's right table. Returns `Ok(None)` when the stage contains no
    /// lookup join; exchange nodes are not crossed.
    fn collect_stage_lookup_join_parallelism(
        &self,
        node: PlanRef,
    ) -> SchedulerResult<Option<usize>> {
        if node.node_type() == PlanNodeType::BatchExchange {
            // Do not visit next stage.
            return Ok(None);
        }

        if let Some(lookup_join) = node.as_batch_lookup_join() {
            let right_table_desc = lookup_join.right_table_desc();
            let right_table_catalog = self
                .catalog_reader
                .read_guard()
                .get_any_table_by_id(&right_table_desc.table_id)
                .cloned()
                .map_err(RwError::from)?;
            let mapping = self
                .worker_node_manager
                .fragment_mapping(right_table_catalog.fragment_id)?;
            let distinct_slots = mapping.iter().sorted().dedup().count();
            return Ok(Some(distinct_slots));
        }

        for input in node.inputs() {
            if let Some(parallelism) = self.collect_stage_lookup_join_parallelism(input)? {
                return Ok(Some(parallelism));
            }
        }
        Ok(None)
    }
}

/// Try to derive the partitions to read from the scan ranges, grouped by the worker slot that
/// owns the corresponding vnodes.
///
/// * Without any scan range, every worker slot of `vnode_mapping` receives its complete vnode
///   bitmap and an empty list of scan ranges.
/// * Otherwise each scan range is routed through `ScanRange::try_compute_vnode`. If a vnode can
///   be computed (the value of the distribution key is already known), the range is attached
///   only to the slot owning that vnode. If not, the range is attached to every slot together
///   with that slot's vnode bitmap.
///
/// # Panics
///
/// Panics if the vnode count of the table differs from the length of `vnode_mapping` and is
/// not 1.
fn derive_partitions(
    scan_ranges: &[ScanRange],
    table_desc: &TableDesc,
    vnode_mapping: &WorkerSlotMapping,
) -> SchedulerResult<HashMap<WorkerSlotId, TablePartitionInfo>> {
    let vnode_mapping = if table_desc.vnode_count != vnode_mapping.len() {
        // The vnode count mismatch occurs only in special cases where a hash-distributed fragment
        // contains singleton internal tables. e.g., the state table of `Source` executors.
        // In this case, we reduce the vnode mapping to a single vnode as only `SINGLETON_VNODE` is used.
        assert!(
            table_desc.vnode_count == 1,
            "fragment vnode count {} does not match table vnode count {}",
            vnode_mapping.len(),
            table_desc.vnode_count,
        );
        &WorkerSlotMapping::new_single(vnode_mapping.iter().next().unwrap())
    } else {
        vnode_mapping
    };
    let vnode_count = vnode_mapping.len();

    // No scan range at all: every slot scans all of the vnodes it owns.
    if scan_ranges.is_empty() {
        let full_scan = vnode_mapping
            .to_bitmaps()
            .into_iter()
            .map(|(slot, vnode_bitmap)| {
                let partition = TablePartitionInfo {
                    vnode_bitmap,
                    scan_ranges: vec![],
                };
                (slot, partition)
            })
            .collect();
        return Ok(full_scan);
    }

    let table_distribution = TableDistribution::new_from_storage_table_desc(
        Some(Bitmap::ones(vnode_count).into()),
        &table_desc.try_to_protobuf()?,
    );

    // Per worker slot: the vnodes to read and the scan ranges to apply to them.
    let mut pending: HashMap<WorkerSlotId, (BitmapBuilder, Vec<_>)> = HashMap::new();

    for scan_range in scan_ranges {
        if let Some(vnode) = scan_range.try_compute_vnode(&table_distribution) {
            // scan a single partition
            let (bitmap, ranges) = pending
                .entry(vnode_mapping[vnode])
                .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
            bitmap.set(vnode.to_index(), true);
            ranges.push(scan_range.to_protobuf());
        } else {
            // put this scan_range to all partitions
            for (slot, slot_bitmap) in vnode_mapping.to_bitmaps() {
                let (bitmap, ranges) = pending
                    .entry(slot)
                    .or_insert_with(|| (BitmapBuilder::zeroed(vnode_count), vec![]));
                for (index, bit) in slot_bitmap.iter().enumerate() {
                    bitmap.set(index, bit);
                }
                ranges.push(scan_range.to_protobuf());
            }
        }
    }

    let partitions = pending
        .into_iter()
        .map(|(slot, (bitmap, scan_ranges))| {
            let partition = TablePartitionInfo {
                vnode_bitmap: bitmap.finish(),
                scan_ranges,
            };
            (slot, partition)
        })
        .collect();
    Ok(partitions)
}

#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use risingwave_pb::batch_plan::plan_node::NodeBody;

    use crate::optimizer::plan_node::PlanNodeType;
    use crate::scheduler::plan_fragmenter::StageId;

    /// Checks the stage graph of the query built by
    /// `crate::scheduler::distributed::tests::create_query`: the number of stages, the
    /// parent/child edges, the topological order and the plan nodes inside every stage.
    ///
    /// The edge assertions below describe this shape:
    /// stage 0 (exchange root) -> stage 1 (hash join) -> stages 2 and 3 (table scans).
    #[tokio::test]
    async fn test_fragmenter() {
        let query = crate::scheduler::distributed::tests::create_query().await;

        assert_eq!(query.stage_graph.root_stage_id, 0);
        assert_eq!(query.stage_graph.stages.len(), 4);

        // Check the mappings of child edges.
        assert_eq!(query.stage_graph.child_edges[&0], [1].into());
        assert_eq!(query.stage_graph.child_edges[&1], [2, 3].into());
        assert_eq!(query.stage_graph.child_edges[&2], HashSet::new());
        assert_eq!(query.stage_graph.child_edges[&3], HashSet::new());

        // Check the mappings of parent edges.
        assert_eq!(query.stage_graph.parent_edges[&0], HashSet::new());
        assert_eq!(query.stage_graph.parent_edges[&1], [0].into());
        assert_eq!(query.stage_graph.parent_edges[&2], [1].into());
        assert_eq!(query.stage_graph.parent_edges[&3], [1].into());

        // Verify topology order
        {
            // Position of every stage in the order produced by `stage_ids_by_topo_order`.
            let stage_id_to_pos: HashMap<StageId, usize> = query
                .stage_graph
                .stage_ids_by_topo_order()
                .enumerate()
                .map(|(pos, stage_id)| (stage_id, pos))
                .collect();

            // Every child must appear before its parent in that order.
            for stage_id in query.stage_graph.stages.keys() {
                let stage_pos = stage_id_to_pos[stage_id];
                for child_stage_id in &query.stage_graph.child_edges[stage_id] {
                    let child_pos = stage_id_to_pos[child_stage_id];
                    assert!(stage_pos > child_pos);
                }
            }
        }

        // Check plan node in each stages.
        // Stage 0: a lone exchange that reads from stage 1.
        let root_exchange = query.stage_graph.stages.get(&0).unwrap();
        assert_eq!(root_exchange.root.node_type(), PlanNodeType::BatchExchange);
        assert_eq!(root_exchange.root.source_stage_id, Some(1));
        assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_)));
        assert_eq!(root_exchange.parallelism, Some(1));
        assert!(!root_exchange.has_table_scan());

        // Stage 1: a hash join whose two children are exchanges reading from stages 2 and 3.
        let join_node = query.stage_graph.stages.get(&1).unwrap();
        assert_eq!(join_node.root.node_type(), PlanNodeType::BatchHashJoin);
        assert_eq!(join_node.parallelism, Some(24));

        assert!(matches!(join_node.root.node, NodeBody::HashJoin(_)));
        assert_eq!(join_node.root.source_stage_id, None);
        assert_eq!(2, join_node.root.children.len());

        assert!(matches!(
            join_node.root.children[0].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[0].source_stage_id, Some(2));
        assert_eq!(0, join_node.root.children[0].children.len());

        assert!(matches!(
            join_node.root.children[1].node,
            NodeBody::Exchange(_)
        ));
        assert_eq!(join_node.root.children[1].source_stage_id, Some(3));
        assert_eq!(0, join_node.root.children[1].children.len());
        assert!(!join_node.has_table_scan());

        // Stage 2: a bare sequential scan.
        let scan_node1 = query.stage_graph.stages.get(&2).unwrap();
        assert_eq!(scan_node1.root.node_type(), PlanNodeType::BatchSeqScan);
        assert_eq!(scan_node1.root.source_stage_id, None);
        assert_eq!(0, scan_node1.root.children.len());
        assert!(scan_node1.has_table_scan());

        // Stage 3: a filter with a single child; the stage still reports a table scan.
        let scan_node2 = query.stage_graph.stages.get(&3).unwrap();
        assert_eq!(scan_node2.root.node_type(), PlanNodeType::BatchFilter);
        assert_eq!(scan_node2.root.source_stage_id, None);
        assert_eq!(1, scan_node2.root.children.len());
        assert!(scan_node2.has_table_scan());
    }
}
